#!/usr/bin/env python3
# Copyright (C) 2023 Checkmk GmbH - License: GNU General Public License v2
# This file is part of Checkmk (https://checkmk.com). It is subject to the terms and
# conditions defined in the file COPYING, which is part of this source code package.
"""Middleware support for conditional request profiling in WSGI applications.

Enables runtime toggling of performance profiling based on configurable triggers, such as specific
query parameters or configuration settings.

This module provides two ready-to-use profiling middlewares:

    DirectWrappingProfilingMiddleware
        directly wraps an already imported WSGI application

    LazyImportProfilingMiddleware
        lazily imports and wraps a WSGI application generated by a factory function, given by
        its dotted module-path.
"""

#  NOTE:
#      The location of this file is intentionally under cmk.utils
#      and not cmk.gui.wsgi, in order to prevent the import of cmk
#      before the debugger has a chance to run.

from __future__ import annotations

import abc
import cProfile
import dataclasses
import importlib
import logging
import pathlib
import pstats
import threading
import typing
import urllib.parse
from wsgiref.types import StartResponse, WSGIApplication, WSGIEnvironment

import pyprof2calltree

# NOTE(review): `P` is not referenced anywhere else in this file; it may be a leftover — confirm
# before removing it.
P = typing.ParamSpec("P")

# Module-level logger used by the config loader and the middlewares for debug output.
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ProfileSetting:
    """Profiling configuration consumed by the profiling middlewares.

    Attributes:
        mode: ``True`` profiles every request, ``False`` profiles none, and ``"enable_by_var"``
            profiles only requests carrying a truthy ``_profile`` query parameter.
        cachegrind_file: Destination of the calltree output produced via ``pyprof2calltree``.
        profile_file: Destination of the raw ``cProfile`` stats dump. A helper script with the
            same stem and a ``.py`` suffix is created next to it.
        accumulate: When ``False``, every profiled request starts with a fresh profiler. When
            ``True``, data is accumulated over multiple requests.
        discard_first_request: Never profile the first request handled by the middleware.
    """

    mode: typing.Literal[True, False, "enable_by_var"]
    cachegrind_file: pathlib.Path
    profile_file: pathlib.Path
    accumulate: bool = dataclasses.field(default=False)
    discard_first_request: bool = dataclasses.field(default=False)


# Selects which fetch function a `ProfileConfigLoader` calls: "default" -> fetch_default_config,
# "imported" -> fetch_actual_config.
ProfileConfigMode = typing.Literal["default", "imported"]


class ProfileConfigLoader:
    """Lazy load profiling configuration

    This class allows switching between two modes. In the `default` mode, a default configuration
    is returned without having to import the config module (and triggering hundreds of subsequent
    imports). In the `imported` mode the loader fetches the config by calling the `fetch_config`
    callable.

    The fetching functions are passed into the constructor for adhering to dependency injection
    principles which allow for better testability.

    """

    def __init__(
        self,
        *,
        fetch_actual_config: typing.Callable[[], ProfileSetting],
        fetch_default_config: typing.Callable[[], ProfileSetting],
    ) -> None:
        self.mode: ProfileConfigMode = "default"
        self.fetch_default_config = fetch_default_config
        self.fetch_config = fetch_actual_config

    def __call__(self) -> ProfileSetting:
        if self.mode == "default":
            logger.debug("Loaded default config")
            return self.fetch_default_config()

        if self.mode == "imported":
            logger.debug("Loaded imported config")
            return self.fetch_config()

        raise RuntimeError(f"Unknown mode: {self.mode}")

    def switch_mode(self, mode: ProfileConfigMode) -> None:
        self.mode = mode


class ProfilingMiddleware(abc.ABC):
    """Base class of profiling middleware.

    Subclasses implement `load_app`, which returns the WSGI application to wrap. The application
    is loaded lazily, on the first request that needs it.

    Whether a request is profiled is decided per request by `should_profile`, based on the
    current `ProfileSetting` returned by the config loader:

        * ``mode is True``: every request is profiled
        * ``mode is False``: no request is profiled
        * ``mode == "enable_by_var"``: only requests with a truthy ``_profile`` query parameter

    If ``discard_first_request`` is set, the first request handled by this middleware is never
    profiled.

    Attributes:
        config_loader (ProfileConfigLoader): Stateful loader for the profiling configuration.

    Args:
        config_loader (ProfileConfigLoader): Stateful loader for the profiling configuration.

    """

    def __init__(self, config_loader: ProfileConfigLoader) -> None:
        self.config_loader = config_loader

        self._first_request = True
        self._profiler: cProfile.Profile = cProfile.Profile()
        self._profile_setting: ProfileSetting = config_loader()
        # Serializes profiled requests; they share `self._profiler` and the output files.
        self._thread_lock = threading.Lock()
        self._wsgi_app: WSGIApplication | None = None

    @abc.abstractmethod
    def load_app(self) -> WSGIApplication:
        """Return the WSGI application to wrap (called lazily when the app is first needed)."""
        raise NotImplementedError

    def _get_app(self) -> WSGIApplication:
        """Return the wrapped application, loading it via `load_app` on first use."""
        if self._wsgi_app is None:
            self._wsgi_app = self.load_app()
        return self._wsgi_app

    def __call__(
        self,
        environ: WSGIEnvironment,
        start_response: StartResponse,
    ) -> typing.Iterable[bytes]:
        """Handle a WSGI request, profiling it if the current configuration asks for it.

        Profiled requests are serialized by a lock. Their profiling data is written to the
        configured files after the application returns, and also when it raises.
        """
        # `should_profile` refreshes `self._profile_setting`, so it must run before the
        # `discard_first_request` flag is read from that setting.
        wants_profile = self.should_profile(environ)
        discard_first_request = self._first_request and self._profile_setting.discard_first_request
        # Only the very first request may be discarded. Without this reset, enabling
        # `discard_first_request` would disable profiling for good.
        self._first_request = False

        if not wants_profile or discard_first_request:
            logger.debug("Won't profile")
            return self._get_app()(environ, start_response)

        with self._thread_lock:
            self._reset_profiler()

            # ENABLE: Profiling
            self._profiler.enable()
            try:
                # Loading the app inside the try ensures the profiler is disabled even if loading
                # fails. A lazily loaded app has its import phase included in the profile.
                # NOTE(review): only the application call is profiled; work done later while the
                # server iterates a lazily produced response body is not captured.
                response = self._get_app()(environ, start_response)
            finally:
                # DISABLE: Profiling
                self._profiler.disable()
                self._save_profile_data()

            return response

    def _reset_profiler(self) -> None:
        """Start over with a fresh profiler, unless the setting asks to accumulate data."""
        if self._profile_setting.accumulate:
            return

        logger.debug("Resetting profiler")
        self._profiler = cProfile.Profile()

    def should_profile(self, environ: WSGIEnvironment) -> bool:
        """Return whether the request described by `environ` should be profiled.

        Side effect: re-fetches the configuration from `config_loader` and stores it in
        `self._profile_setting`.
        """
        # Fetch the current configuration.
        self._profile_setting = self.config_loader()

        if self._profile_setting.mode is True:
            logger.debug("Profile mode is True")
            return True
        if self._profile_setting.mode is False:
            logger.debug("Profile mode is False")
            return False

        def is_truthy_query_param(query_string: str, *, param: str) -> bool:
            if not query_string:
                return False
            truthy_values = {"1", "t", "true", "y", "yes", "on"}
            params = urllib.parse.parse_qs(query_string)
            return param in params and any(
                value.lower() in truthy_values for value in params[param]
            )

        # enable_by_var case
        return is_truthy_query_param(environ.get("QUERY_STRING", ""), param="_profile")

    def _save_profile_data(self) -> None:
        """Write the raw stats, the calltree output and (once) a helper script to inspect them."""
        logger.debug("Saving profiling data")

        profile_file = self._profile_setting.profile_file
        script_file = profile_file.with_suffix(".py")

        self._profiler.dump_stats(str(profile_file))
        stats = pstats.Stats(self._profiler)
        conv = pyprof2calltree.CalltreeConverter(stats)
        with open(self._profile_setting.cachegrind_file, "w", encoding="utf-8") as grind:
            conv.output(grind)

        # We don't want to overwrite manual changes to the script file, so we only create it.
        if not script_file.exists():
            with script_file.open("w", encoding="utf-8") as f:
                f.write(
                    "#!/usr/bin/env python3\n"
                    "import pstats\n"
                    # repr() yields a valid Python string literal even if the path contains
                    # quotes or backslashes.
                    f"stats = pstats.Stats({str(profile_file)!r})\n"
                    "stats.sort_stats('cumtime').print_stats()\n"
                )
            script_file.chmod(0o755)


class DirectWrappingProfilingMiddleware(ProfilingMiddleware):
    """Add optional request profiling to an already constructed WSGI application.

    The application is passed in ready-made, so the config loader is switched to ``"imported"``
    mode straight away.

    Attributes:
        config_loader (ProfileConfigLoader): Stateful loader for the profiling configuration.

    Args:
        app_to_wrap (WSGIApplication): The WSGI application to wrap for profiling.
        config_loader (ProfileConfigLoader): Stateful loader for the profiling configuration.

    """

    def __init__(
        self,
        app_to_wrap: WSGIApplication,
        config_loader: ProfileConfigLoader,
    ):
        # Switch before the base initializer runs, because it fetches the initial setting from
        # the loader.
        config_loader.switch_mode("imported")
        self._app_to_wrap = app_to_wrap
        super().__init__(config_loader=config_loader)

    def load_app(self) -> WSGIApplication:
        """Hand back the application that was passed to the constructor."""
        wrapped = self._app_to_wrap
        return wrapped


class LazyImportProfilingMiddleware(ProfilingMiddleware):
    """Add optional request profiling to a WSGI application that is imported on first use.

    Instead of an application object, this middleware receives the dotted path of a module and
    the name of a factory function inside it. Both the import and the factory call are deferred
    until the first request arrives. If that request is profiled, the import phase is part of the
    profile.

    NOTE: This requires the imports to happen inside a function.

    Attributes:
        config_loader (ProfileConfigLoader): Stateful loader for the profiling configuration.

    Args:
        app_factory_module (str): Dotted path of the module that holds the application factory.
        app_factory_name (str): Name of the factory function in that module.
        app_factory_args (tuple[Any, ...]): Positional arguments passed to the factory.
        app_factory_kwargs (dict[str, Any]): Keyword arguments passed to the factory.
        config_loader (ProfileConfigLoader): Stateful loader for the profiling configuration.

    """

    def __init__(
        self,
        app_factory_module: str,
        app_factory_name: str,
        app_factory_args: tuple[typing.Any, ...],
        app_factory_kwargs: dict[str, typing.Any],
        config_loader: ProfileConfigLoader,
    ) -> None:
        # The loader's mode is left as-is here; `load_app` switches it once importing is done.
        self.app_factory_module, self.app_factory_name = app_factory_module, app_factory_name
        self.app_factory_args, self.app_factory_kwargs = app_factory_args, app_factory_kwargs
        super().__init__(config_loader=config_loader)

    def load_app(self) -> WSGIApplication:
        """Import the factory module, switch the loader to "imported" mode, then build the app."""
        factory_module = importlib.import_module(self.app_factory_module)
        factory = getattr(factory_module, self.app_factory_name)
        # Importing is finished, so the loader may now fetch the real config. Fetching it may
        # itself import further modules.
        self.config_loader.switch_mode("imported")
        logger.debug("Loaded application and switched to imported mode.")
        args, kwargs = self.app_factory_args, self.app_factory_kwargs
        return factory(*args, **kwargs)


# Public API of this module (kept in alphabetical order).
__all__ = [
    "DirectWrappingProfilingMiddleware",
    "LazyImportProfilingMiddleware",
    "ProfileConfigLoader",
    "ProfileSetting",
]
